{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "from glob import glob\n",
    "import json\n",
    "import random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rebel_format(triplets):\n",
    "    \"\"\"\n",
    "    Flatten [head, relation, tail] triples into a single REBEL-style string.\n",
    "\n",
    "    Each triple contributes `<subj> tail <obj> relation`. `<triplet> head` is\n",
    "    written in front of it only for the first triple or when the head differs\n",
    "    from the head of the triple immediately before it, so a run of consecutive\n",
    "    triples sharing a head is grouped under one `<triplet>` marker.\n",
    "\n",
    "    Example:\n",
    "    [['Bruno Santana', 'participant of', '2004 Summer Olympics'],\n",
    "     ['Bruno Santana', 'participant of', '2008 Summer Olympics'],\n",
    "     ['Bruno Santana', 'country of citizenship', 'Brazil']]\n",
    "    becomes\n",
    "    <triplet> Bruno Santana <subj> 2004 Summer Olympics <obj> participant of <subj> 2008 Summer Olympics <obj> participant of <subj> Brazil <obj> country of citizenship\n",
    "\n",
    "    Every field is split on whitespace and re-joined with single spaces.\n",
    "    \"\"\"\n",
    "    pieces = []\n",
    "    previous_head = None\n",
    "    for position, triple in enumerate(triplets):\n",
    "        head, relation, tail = triple[0], triple[1], triple[2]\n",
    "        # Only the first triple, or a change of head, opens a new group.\n",
    "        if position == 0 or head != previous_head:\n",
    "            pieces.append('<triplet>')\n",
    "            pieces.extend(head.split())\n",
    "        pieces.append('<subj>')\n",
    "        pieces.extend(tail.split())\n",
    "        pieces.append('<obj>')\n",
    "        pieces.extend(relation.split())\n",
    "        previous_head = head\n",
    "    return ' '.join(pieces)\n",
    "\n",
    "def parse_rebel(text):\n",
    "    \"\"\"\n",
    "    Parse a REBEL-style string back into a list of triples.\n",
    "\n",
    "    Expected layout: `<triplet> head <subj> tail <obj> relation`, where further\n",
    "    `<subj> tail <obj> relation` pairs reuse the most recent head.\n",
    "\n",
    "    Markers are recognised only as standalone whitespace-separated tokens. The\n",
    "    literal strings `<s>`, `<pad>` and `</s>` are removed before parsing, and\n",
    "    tokens in front of the first marker are ignored.\n",
    "\n",
    "    Returns a list of {'head': ..., 'type': ..., 'tail': ...} dicts in order of\n",
    "    appearance. A triple is emitted only when head, type and tail are all\n",
    "    non-empty, at every boundary, so a malformed fragment such as\n",
    "    `<triplet> <subj> x <obj> y` is dropped instead of yielding an empty head.\n",
    "    \"\"\"\n",
    "    triplets = []\n",
    "    subject, relation, object_ = '', '', ''\n",
    "    # Which field the next plain token belongs to:\n",
    "    # 'x' = no marker seen yet, 't' = head, 's' = tail, 'o' = relation.\n",
    "    current = 'x'\n",
    "\n",
    "    def flush():\n",
    "        # One completeness rule for every boundary (next marker or end of text).\n",
    "        head, type_, tail = subject.strip(), relation.strip(), object_.strip()\n",
    "        if head and type_ and tail:\n",
    "            triplets.append({'head': head, 'type': type_, 'tail': tail})\n",
    "\n",
    "    cleaned = text.replace('<s>', '').replace('<pad>', '').replace('</s>', '')\n",
    "    for token in cleaned.split():\n",
    "        if token == '<triplet>':\n",
    "            flush()\n",
    "            current = 't'\n",
    "            # A new head invalidates everything collected for the old one.\n",
    "            subject, relation, object_ = '', '', ''\n",
    "        elif token == '<subj>':\n",
    "            flush()\n",
    "            current = 's'\n",
    "            # Clear the relation as well, so a stray second `<subj>` cannot\n",
    "            # re-emit the previous relation with a different tail.\n",
    "            relation, object_ = '', ''\n",
    "        elif token == '<obj>':\n",
    "            current = 'o'\n",
    "            relation = ''\n",
    "        elif current == 't':\n",
    "            subject += ' ' + token\n",
    "        elif current == 's':\n",
    "            object_ += ' ' + token\n",
    "        elif current == 'o':\n",
    "            relation += ' ' + token\n",
    "    flush()\n",
    "    return triplets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['processed-kg-astroawani.jsonl', 'processed-kg-wikipedia.jsonl']"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Collect the per-source knowledge-graph shards. glob() returns matches in\n",
    "# arbitrary order, so sort them to keep the row order of train.json stable\n",
    "# between runs and machines.\n",
    "files = sorted(glob('processed-kg-*.jsonl'))\n",
    "files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Merge every shard listed in `files` into train.json, one JSON object per line.\n",
    "# Each input row must provide `text` (stored as `src`) and `kg` (stored as `tgt`);\n",
    "# the fixed task prefix is stored alongside them.\n",
    "# Both files are opened as UTF-8 explicitly so the result does not depend on the\n",
    "# locale; the output is flushed once, when the `with` block closes it.\n",
    "with open('train.json', 'w', encoding='utf-8') as fopen_jsonl:\n",
    "    for f in files:\n",
    "        with open(f, encoding='utf-8') as fopen:\n",
    "            for line in fopen:\n",
    "                row = json.loads(line)\n",
    "                d = {\"translation\": {\"src\": row['text'], \"tgt\": row['kg'], 'prefix': 'teks ke grafik pengetahuan: '}}\n",
    "                fopen_jsonl.write(f'{json.dumps(d)}\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\"translation\": {\"src\": \"Padah jalin hubungan sulit dengan pekerja sendiri, CEO McDonald's dipecat serta merta\", \"tgt\": \"<triplet> Padah <subj> hubungan sulit <obj> mempunyai <triplet> hubungan sulit <subj> pekerja sendiri <obj> dengan <triplet> Padah <subj> CEO McDonald's <obj> dipecat\", \"prefix\": \"teks ke grafik pengetahuan: \"}}\r\n",
      "{\"translation\": {\"src\": \"CEO tidak boleh menjalin hubungan dengan mana-mana kakitangan.\", \"tgt\": \"<triplet> CEO <subj> kakitangan <obj> tidak boleh menjalin hubungan dengan\", \"prefix\": \"teks ke grafik pengetahuan: \"}}\r\n",
      "{\"translation\": {\"src\": \"SYARIKAT rantaian makanan segera terkemuka dunia, McDonald's Corp mengesahkan telah memecat Ketua Pegawai Eksekutif (CEO), Steve Easterbrook selepas menjalinkan hubungan sulit dengan salah seorang kakitangannya.\", \"tgt\": \"<triplet> <subj> yang telah memecat Steve Easterbrook <obj> mengesahkan <triplet> Steve Easterbrook <subj> hubungan yang tidak sesuai dengan pekerja <obj> telah\", \"prefix\": \"teks ke grafik pengetahuan: \"}}\r\n"
     ]
    }
   ],
   "source": [
    "# Peek at the first three rows of train.json to check the written format.\n",
    "!head -n 3 train.json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write a shuffled copy of train.json to shuf-train.json.\n",
    "# NOTE(review): no random source is passed to shuf, so the order presumably\n",
    "# differs on every run - confirm whether that is acceptable.\n",
    "!shuf train.json > shuf-train.json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "190102 train.json\r\n"
     ]
    }
   ],
   "source": [
    "# Count the rows written to train.json.\n",
    "!wc -l train.json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2532062it [00:08, 295700.23it/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "65582"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Collect the rows of shuf-train.json whose target has more than 100\n",
    "# whitespace-separated tokens, then report how many there are.\n",
    "# NOTE(review): the saved progress output shows 2,532,062 rows read here while\n",
    "# the wc cell reports 190,102 rows in train.json, so shuf-train.json looks like\n",
    "# it came from a different run - re-run the shuffle before relying on this.\n",
    "with open('shuf-train.json') as fopen:\n",
    "    parsed_rows = (json.loads(raw_line) for raw_line in tqdm(fopen))\n",
    "    data = [\n",
    "        row for row in parsed_rows\n",
    "        if len(row['translation']['tgt'].split()) > 100\n",
    "    ]\n",
    "\n",
    "len(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'translation': {'src': 'Dameshkaft-e Kalat (, also Romanized as Dameshkaft-e Kalāt; also known as Dameshkaft) is a village in Doshman Ziari Rural District, in the Central District of Kohgiluyeh County, Kohgiluyeh and Boyer-Ahmad Province, Iran. At the 2006 census, its population was 151, in 30 families.',\n",
       "  'tgt': '<triplet> Dameshkaft-e Kalat <subj> Daerah Luar Bandar Doshman Ziari <obj> terletak dalam entiti wilayah pentadbiran <triplet> Daerah Luar Bandar Doshman Ziari <subj> Daerah Tengah <obj> terletak dalam entiti wilayah pentadbiran <subj> Daerah Kohgiluyeh <obj> sebahagian daripada <subj> Iran <obj> negara <triplet> Daerah Tengah <subj> Daerah Kohgiluyeh <obj> terletak dalam entiti wilayah pentadbiran <subj> Wilayah Kohgiluyeh dan Boyer-Ahmad <obj> terletak dalam entiti wilayah pentadbiran <subj> Iran <obj> negara <triplet> Daerah Kohgiluyeh <subj> Daerah Tengah <obj> mengandungi entiti wilayah pentadbiran <subj> Wilayah Kohgiluyeh dan Boyer-Ahmad <obj> terletak dalam entiti wilayah pentadbiran <subj> Iran <obj> negara <triplet> Wilayah Kohgiluyeh dan Boyer-Ahmad <subj> Daerah Kohgiluyeh <obj> mengandungi entiti wilayah pentadbiran <subj> Iran <obj> terletak dalam entiti wilayah pentadbiran <triplet> Iran <subj> Wilayah Kohgiluyeh dan Boyer-Ahmad <obj> mengandungi entiti wilayah pentadbiran',\n",
       "  'prefix': 'teks ke grafik pengetahuan: '}}"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect one of the long-target rows (second from the end of `data`).\n",
    "data[-2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'head': 'Dameshkaft-e Kalat',\n",
       "  'type': 'terletak dalam entiti wilayah pentadbiran',\n",
       "  'tail': 'Daerah Luar Bandar Doshman Ziari'},\n",
       " {'head': 'Daerah Luar Bandar Doshman Ziari',\n",
       "  'type': 'terletak dalam entiti wilayah pentadbiran',\n",
       "  'tail': 'Daerah Tengah'},\n",
       " {'head': 'Daerah Luar Bandar Doshman Ziari',\n",
       "  'type': 'sebahagian daripada',\n",
       "  'tail': 'Daerah Kohgiluyeh'},\n",
       " {'head': 'Daerah Luar Bandar Doshman Ziari',\n",
       "  'type': 'negara',\n",
       "  'tail': 'Iran'},\n",
       " {'head': 'Daerah Tengah',\n",
       "  'type': 'terletak dalam entiti wilayah pentadbiran',\n",
       "  'tail': 'Daerah Kohgiluyeh'},\n",
       " {'head': 'Daerah Tengah',\n",
       "  'type': 'terletak dalam entiti wilayah pentadbiran',\n",
       "  'tail': 'Wilayah Kohgiluyeh dan Boyer-Ahmad'},\n",
       " {'head': 'Daerah Tengah', 'type': 'negara', 'tail': 'Iran'},\n",
       " {'head': 'Daerah Kohgiluyeh',\n",
       "  'type': 'mengandungi entiti wilayah pentadbiran',\n",
       "  'tail': 'Daerah Tengah'},\n",
       " {'head': 'Daerah Kohgiluyeh',\n",
       "  'type': 'terletak dalam entiti wilayah pentadbiran',\n",
       "  'tail': 'Wilayah Kohgiluyeh dan Boyer-Ahmad'},\n",
       " {'head': 'Daerah Kohgiluyeh', 'type': 'negara', 'tail': 'Iran'},\n",
       " {'head': 'Wilayah Kohgiluyeh dan Boyer-Ahmad',\n",
       "  'type': 'mengandungi entiti wilayah pentadbiran',\n",
       "  'tail': 'Daerah Kohgiluyeh'},\n",
       " {'head': 'Wilayah Kohgiluyeh dan Boyer-Ahmad',\n",
       "  'type': 'terletak dalam entiti wilayah pentadbiran',\n",
       "  'tail': 'Iran'},\n",
       " {'head': 'Iran',\n",
       "  'type': 'mengandungi entiti wilayah pentadbiran',\n",
       "  'tail': 'Wilayah Kohgiluyeh dan Boyer-Ahmad'}]"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: parse that row's linearised target back into head/type/tail dicts.\n",
    "parse_rebel(data[-2]['translation']['tgt'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\"translation\": {\"src\": \"Vladimir Luxuria (lahir 24 Jun 1965 di Foggia, Apulia) ialah seorang pelakon, penulis, ahli politik dan hos televisyen trans Itali. Luxuria ialah ahli parlimen Parti Kebangkitan Komunis, kepunyaan gabungan Kesatuan yang diketuai oleh Romano Prodi.\\n\\nBeliau adalah ahli Parlimen transgender terbuka pertama di Eropah, dan ahli parlimen transgender terbuka kedua di dunia selepas warga New Zealand Georgina Beyer. Dia kehilangan kerusinya dalam pilihan raya April 2008.\\n\\nDalam pilihan raya umum 2006, Luxuria telah dipilih ke Dewan Perwakilan oleh kawasan Lazio 1 di Rom. Dia kehilangan kerusinya dalam pilihan raya 2008. Selepas persaraan Beyer dan Luxuria, tiada ahli parlimen transgender dilaporkan di dunia, sehingga 2011, apabila Anna Grodzka dipilih ke parlimen Poland.\", \"tgt\": \"<triplet> Vladimir Luxuria <subj> Foggia <obj> tempat kelahiran <subj> hos televisyen <obj> pekerjaan <subj> 24 Jun 1965 <obj> tarikh lahir <subj> Parti Pemulihan Komunis <obj> ahli parti politik <triplet> Parti Pemulihan Komunis <subj> Kesatuan <obj> sebahagian daripada\", \"prefix\": \"teks ke grafik pengetahuan: \"}}\r\n",
      "{\"translation\": {\"src\": \"Adelaide Adrenaline ialah pasukan hoki ais separuh profesional yang berpangkalan di Adelaide, Australia Selatan. Pasukan ini adalah ahli Liga Hoki Ais Australia (AIHL). Pasukan ini diasaskan pada 2008 sebagai Adelaide A untuk menggantikan Adelaide Avalanche yang tidak berfungsi yang melipat pertengahan musim. Pasukan ini bermain di tempat sendiri di IceArenA, yang terletak di pinggir bandar Thebarton.\", \"tgt\": \"<triplet> Adelaide Adrenaline <subj> hoki ais <obj> sukan <subj> 2008 <obj> penubuhan <triplet> IceArenA <subj> Thebarton <obj> lokasi\", \"prefix\": \"teks ke grafik pengetahuan: \"}}\r\n",
      "{\"translation\": {\"src\": \"Josef Vaek (; 12 September 1980 - 7 September 2011) ialah pemain hoki ais profesional Czech. Kali terakhir Vaek bermain untuk Lokomotiv Yaroslavl dari Liga Hoki Kontinental (KHL) dan meninggal dunia dalam nahas pesawat Lokomotiv Yaroslavl pada 7 September 2011. Dia telah bermain tujuh musim dalam Liga Hoki Kebangsaan (NHL) untuk Carolina Hurricanes, Nashville Predators dan New York Islanders sebelum berpindah ke Rusia pada 2008 untuk bermain untuk Yaroslavl.\", \"tgt\": \"<triplet> Josef Vaek <subj> hoki ais <obj> sukan <subj> 12 September 1980 <obj> tarikh lahir <subj> 7 September 2011 <obj> tarikh kematian <subj> Lokomotiv Yaroslavl <obj> ahli pasukan sukan <subj> 7 September 2011 <obj> tarikh kematian <subj> Nashville Predators <obj> ahli pasukan sukan <subj> New York Islanders <obj> ahli pasukan sukan <triplet> Lokomotiv Yaroslavl <subj> Liga Hoki Kontinental <obj> liga <triplet> Lokomotiv Yaroslavl pesawat terhempas <subj> 7 September 2011 <obj> masa tepat\", \"prefix\": \"teks ke grafik pengetahuan: \"}}\r\n"
     ]
    }
   ],
   "source": [
    "# Peek at the first three rows of the shuffled training file.\n",
    "!head -n 3 shuf-train.json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input for the test split; it is read line by line as JSON further down.\n",
    "# NOTE(review): absolute, machine-specific path - consider making it configurable.\n",
    "test_file = '/home/husein/ssd3/ctranslate2/en_test.jsonl.translated'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "152836it [00:05, 28895.53it/s]\n"
     ]
    }
   ],
   "source": [
    "# Build test.json from the translated test file, one JSON object per line.\n",
    "# A row is kept only if it has Malay triples and Malay text, the text is at most\n",
    "# 400 whitespace-separated words and the linearised target at most 200.\n",
    "# The Malay text is always written; the English text of the same row is also\n",
    "# written, with the same target, for roughly half of the rows.\n",
    "rng = random.Random(0)  # fixed seed so the English sampling is reproducible\n",
    "\n",
    "with open('test.json', 'w', encoding='utf-8') as fopen_jsonl, open(test_file, encoding='utf-8') as fopen:\n",
    "    for line in tqdm(fopen):\n",
    "        row = json.loads(line)\n",
    "\n",
    "        if not row['triples_ms'] or not row['text_ms']:\n",
    "            continue\n",
    "        if len(row['text_ms'].split()) > 400:\n",
    "            continue\n",
    "\n",
    "        # Drop triples with a missing, empty or whitespace-only field; a blank\n",
    "        # field would otherwise produce a marker with nothing after it.\n",
    "        triples = [\n",
    "            [t['head'], t['type'], t['tail']]\n",
    "            for t in row['triples_ms']\n",
    "            if all((t[key] or '').strip() for key in ('head', 'type', 'tail'))\n",
    "        ]\n",
    "        if not triples:\n",
    "            continue\n",
    "\n",
    "        right = rebel_format(triples)\n",
    "        if len(right.split()) > 200:\n",
    "            continue\n",
    "\n",
    "        left = row['text_ms'].strip()\n",
    "        left_en = row['text'].strip()\n",
    "\n",
    "        # `right` is non-empty and within the limits at this point, so only the\n",
    "        # source side still needs checking.\n",
    "        sources = []\n",
    "        if left:\n",
    "            sources.append(left)\n",
    "        # The English text was not covered by the 400-word check above.\n",
    "        if rng.random() > 0.5 and left_en and len(left_en.split()) < 1536:\n",
    "            sources.append(left_en)\n",
    "\n",
    "        for src in sources:\n",
    "            d = {\"translation\": {\"src\": src, \"tgt\": right, 'prefix': 'teks ke grafik pengetahuan: '}}\n",
    "            fopen_jsonl.write(f'{json.dumps(d)}\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write a shuffled copy of test.json to shuf-test.json.\n",
    "# NOTE(review): no random source is passed to shuf, so the order presumably\n",
    "# differs on every run - confirm whether that is acceptable.\n",
    "!shuf test.json > shuf-test.json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep the first 4000 rows of test.json as the smaller file test-4k.json.\n",
    "# NOTE(review): this reads test.json, not the shuf-test.json written by the\n",
    "# shuf cell, so the subset is the unshuffled head of the file - confirm that\n",
    "# this is intended.\n",
    "!head -n 4000 test.json > test-4k.json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "from datasets import load_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    }
   ],
   "source": [
    "# Tokenizer of the mesolitica/nanot5-small-malaysian-cased checkpoint.\n",
    "tokenizer = AutoTokenizer.from_pretrained('mesolitica/nanot5-small-malaysian-cased')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading and preparing dataset json/default to /home/husein/.cache/huggingface/datasets/json/default-d5f4b4fd406d1504/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0791a1ea14ba4028b024131b196a311d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9677a28f5c8043628fd021136b3be498",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "dedccad748a44bf5b8104653bb579725",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating train split: 0 examples [00:00, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset json downloaded and prepared to /home/husein/.cache/huggingface/datasets/json/default-d5f4b4fd406d1504/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4. Subsequent calls will reuse this data.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8715e514d9b84baaae5375d2d36e969a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Load test-4k.json with the generic JSON loader. No split name is given, and\n",
    "# the rows are read back through raw_datasets['train'] in the cells that follow.\n",
    "raw_datasets = load_dataset(\n",
    "    'json',\n",
    "    data_files='test-4k.json',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'translation': {'src': 'Istana Beardslee ialah sebuah istana tiga tingkat di Little Falls, New York, Amerika Syarikat, dibina pada tahun 1860 sebagai replika istana Ireland, dan kini digunakan sebagai restoran. Manor itu telah dibina semula dua kali, selepas dibakar oleh kebakaran pada tahun 1919 dan 1989.\\n\\nIstana Beardslee ialah projek Lavina (Pardee) Beardslee yang balu, bermula pada tahun-tahun terakhirnya sekitar akhir 1790-an. Anak lelakinya, Augustus Beardslee, dikreditkan dengan bangunan Istana, walaupun cucu Lavina, Kapten Guy Roosevelt Beardslee (yang dilahirkan di rumah agam), yang menyelia siapnya projek sekitar tahun 1860. Kapten Beardslee menjalankan projek perniagaan yang memperoleh elektrik dan penjana dengan sistem pengedaran memulakan loji kuasa hidro tempatan yang kecil. Syarikat kecil ini berjaya masuk ke dalam Niagara Mohawk Power Corporation.',\n",
       "  'tgt': '<triplet> Beardslee Castle <subj> USA <obj> country <subj> restaurant <obj> contoh <triplet> Little Falls, New York <subj> USA <obj> country',\n",
       "  'prefix': 'teks ke grafik pengetahuan: '}}"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the first loaded example.\n",
    "raw_datasets['train'][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "source_lang = 'src'\n",
    "target_lang = 'tgt'\n",
    "# Padding strategy handed to both tokenizer calls below.\n",
    "padding = True\n",
    "# Only consulted when padding == \"max_length\": replace label pad ids with -100.\n",
    "ignore_pad_token_for_loss = True\n",
    "\n",
    "def preprocess_function(examples):\n",
    "    \"\"\"\n",
    "    Tokenize a batch of {'translation': {'src', 'tgt', 'prefix'}} rows.\n",
    "\n",
    "    The model input is prefix + source text + EOS, the label text is\n",
    "    target text + EOS; both are truncated to 1024 tokens and padded according\n",
    "    to the module-level `padding` setting.\n",
    "\n",
    "    Returns the tokenizer output for the inputs with the target ids stored\n",
    "    under 'labels' and any 'token_type_ids' entry removed.\n",
    "    \"\"\"\n",
    "    # NOTE(review): eos_token is appended as text here; if the tokenizer also\n",
    "    # appends EOS on its own, sequences may end with two of them - confirm.\n",
    "    inputs = [ex['prefix'] + ex[source_lang] +\n",
    "              tokenizer.eos_token for ex in examples[\"translation\"]]\n",
    "    targets = [ex[target_lang] + tokenizer.eos_token for ex in examples[\"translation\"]]\n",
    "    model_inputs = tokenizer(\n",
    "        inputs,\n",
    "        max_length=1024,\n",
    "        padding=padding,\n",
    "        truncation=True)\n",
    "\n",
    "    # Tokenize targets with the `text_target` keyword argument\n",
    "    labels = tokenizer(\n",
    "        text_target=targets,\n",
    "        max_length=1024,\n",
    "        padding=padding,\n",
    "        truncation=True)\n",
    "\n",
    "    # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore\n",
    "    # padding in the loss.\n",
    "    # NOTE(review): with padding=True the labels are padded as well but are not\n",
    "    # masked here - confirm that the training collator handles label padding.\n",
    "    if padding == \"max_length\" and ignore_pad_token_for_loss:\n",
    "        labels[\"input_ids\"] = [[(l if l != tokenizer.pad_token_id else -100)\n",
    "                                for l in label] for label in labels[\"input_ids\"]]\n",
    "\n",
    "    model_inputs[\"labels\"] = labels[\"input_ids\"]\n",
    "    model_inputs.pop('token_type_ids', None)\n",
    "    return model_inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "75db62dd83794583949689f745bcfad7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/4000 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Tokenize the whole split. batched=True passes preprocess_function a batch\n",
    "# at a time, which is why it iterates over examples['translation'].\n",
    "train_dataset = raw_datasets['train'].map(\n",
    "    preprocess_function,\n",
    "    batched=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['translation', 'input_ids', 'attention_mask', 'labels'],\n",
       "    num_rows: 4000\n",
       "})"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the columns and row count of the tokenized dataset.\n",
    "train_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['translation', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],\n",
       "    num_rows: 4000\n",
       "})"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NOTE(review): duplicate of the previous cell. Its saved output still lists\n",
    "# token_type_ids and carries an earlier execution count than the cell that\n",
    "# defines preprocess_function, so it appears to be stale - re-run or remove.\n",
    "train_dataset"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
